#%%
import torch 
import tensorly as tly

import collections

# Compatibility shim: Python 3.10 removed the container ABC aliases from the
# top-level ``collections`` namespace. Re-create them here, before the imports
# below run -- presumably one of those dependencies still uses the old names.
# ``import collections`` alone does not guarantee that the ``collections.abc``
# submodule is loaded, so import it explicitly instead of relying on another
# package having done so as a side effect.
import collections.abc

collections.Iterable = collections.abc.Iterable
collections.Mapping = collections.abc.Mapping
collections.MutableSet = collections.abc.MutableSet
collections.MutableMapping = collections.abc.MutableMapping

import tltorch
import math
import collections

from .matrix_conv import Conv2d_USV

class Flatten(torch.nn.Module):
    """Collapse every dimension after the first into a single dimension."""

    def forward(self, input):
        """Return ``input`` viewed as ``(input.size(0), -1)``.

        The tensor is made contiguous first so that ``view`` is valid
        regardless of the memory layout of ``input``.
        """
        leading = input.size(0)
        packed = input.contiguous()
        return packed.view(leading, -1)
    

class AlexNet(torch.nn.Module):
    """AlexNet-style CNN whose convolutions are replaced by low-rank layers.

    The network is five ``conv -> BatchNorm2d -> [MaxPool2d(2)] -> ReLU``
    stages followed by ``Flatten -> Linear(1024, 256) -> ReLU ->
    Linear(256, output_dim)``. The first convolution has stride 2 and three
    stages max-pool by 2, so the classifier's ``256 * 2 * 2`` input size
    corresponds to 3x32x32 images.

    ``args.deco`` selects how every convolution is parameterised:

    * ``'cp'`` / ``'tucker'``: ``tltorch.FactorizedConv.from_conv`` with that
      factorization and ``rank=1 - args.tau``.
    * ``'mat'``: ``Conv2d_USV`` with an integer rank of
      ``ceil((1 - args.tau) * full_rank)`` per layer.

    Args:
        output_dim: Number of outputs of the final linear layer.
        device: Stored on the instance as ``self.device``; not otherwise used
            by this class (modules are not moved to it here).
        args: Configuration object exposing ``deco`` (one of ``'cp'``,
            ``'tucker'``, ``'mat'``) and ``tau``. Required despite the
            ``None`` default.

    Raises:
        ValueError: If ``args`` is ``None`` or ``args.deco`` is not one of the
            supported decompositions.
    """

    _SUPPORTED_DECOS = ('cp', 'tucker', 'mat')

    # Keyword arguments of each convolution, exactly as they are forwarded to
    # ``torch.nn.Conv2d`` / ``Conv2d_USV``. Arguments that are absent fall
    # back to the callee's defaults (note that the fourth layer does not set
    # ``bias=False``, unlike the others).
    _CONV_KWARGS = (
        dict(in_channels=3, out_channels=64, kernel_size=3, stride=2, padding=1, bias=False),
        dict(in_channels=64, out_channels=192, kernel_size=3, padding=1, bias=False),
        dict(in_channels=192, out_channels=384, kernel_size=3, padding=1, bias=False),
        dict(in_channels=384, out_channels=256, kernel_size=3, padding=1),
        dict(in_channels=256, out_channels=256, kernel_size=3, padding=1, bias=False),
    )

    # Whether a ``MaxPool2d(2)`` sits between the batch norm and the ReLU of
    # the corresponding stage.
    _POOL_AFTER = (True, True, False, False, True)

    # Per-layer rank that ``tau`` scales in the 'mat' variant. Each value
    # equals min(out_channels, in_channels * 3 * 3) for its layer.
    _MAT_FULL_RANKS = (27, 192, 384, 256, 256)

    def __init__(self, output_dim, device = 'cpu', args = None):
        super().__init__()
        self.device = device
        self.args = args

        # Fail early with a clear message instead of an AttributeError on
        # ``None`` or a model that silently lacks ``lr_model``.
        if args is None:
            raise ValueError(
                "AlexNet requires `args` with `deco` and `tau` attributes"
            )
        if args.deco not in self._SUPPORTED_DECOS:
            raise ValueError(
                f"Unsupported args.deco {args.deco!r}; "
                f"expected one of {self._SUPPORTED_DECOS}"
            )

        # Modules are created in network order so that the Sequential indices
        # (and therefore state_dict keys) and the order in which random
        # initialisation is drawn match the layer sequence.
        layers = []
        stages = zip(self._CONV_KWARGS, self._POOL_AFTER, self._MAT_FULL_RANKS)
        for conv_kwargs, pool_after, full_rank in stages:
            layers.append(
                self._build_conv(args.deco, args.tau, conv_kwargs, full_rank)
            )
            # NOTE(review): PyTorch's ``momentum`` weights the *new* batch
            # statistic, so 0.9 makes the running stats track the latest
            # batch closely -- confirm this is intended.
            layers.append(
                torch.nn.BatchNorm2d(conv_kwargs['out_channels'], momentum=0.9)
            )
            if pool_after:
                layers.append(torch.nn.MaxPool2d(2))
            layers.append(torch.nn.ReLU())

        layers.extend([
            Flatten(),
            torch.nn.Linear(256 * 2 * 2, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, output_dim),
        ])
        self.lr_model = torch.nn.Sequential(*layers)

    @staticmethod
    def _build_conv(deco, tau, conv_kwargs, full_rank):
        """Create one low-rank convolution for the requested decomposition."""
        if deco == 'mat':
            return Conv2d_USV(
                **conv_kwargs, rank=math.ceil((1 - tau) * full_rank)
            )
        # 'cp' and 'tucker': wrap a freshly initialised dense convolution.
        # The float ``rank`` is handed to tltorch unchanged; presumably it is
        # read as a fraction of the dense layer's parameters -- confirm
        # against the tltorch documentation.
        return tltorch.FactorizedConv.from_conv(
            torch.nn.Conv2d(**conv_kwargs),
            rank=1 - tau,
            decompose_weights=False,
            factorization=deco,
        )

    def forward(self, x):
        """Apply the sequential low-rank network to ``x``."""
        return self.lr_model(x)
# %%
